package me.seu.demo.utils;

import org.apache.commons.lang3.StringUtils;
import org.slf4j.MDC;
import org.springframework.scheduling.concurrent.ThreadPoolTaskExecutor;

import java.util.Map;
import java.util.concurrent.RejectedExecutionHandler;
import java.util.concurrent.ThreadPoolExecutor;

/**
 * @author xuguoquan
 * @date 2021/9/29 10:40
 * @description 线程池工具类
 */
public class ThreadPoolUtils {

    /**
     * Defaults used when a builder argument is {@code null} (or blank, for the name prefix):
     * core pool size, maximum pool size, queue capacity, keep-alive seconds,
     * thread-name prefix and rejection policy.
     */
    private static final int CORE_POOL_SIZE = Runtime.getRuntime().availableProcessors();
    private static final int MAX_POOL_SIZE = 2 * CORE_POOL_SIZE;
    private static final int DEFAULT_QUEUE_CAPACITY = 100;
    private static final int DEFAULT_KEEP_ALIVE_SECONDS = 60;
    private static final String DEFAULT_THREAD_NAME_PREFIX = "college-executor-";
    private static final RejectedExecutionHandler DEFAULT_REJECTED_EXECUTION_HANDLER = new ThreadPoolExecutor.CallerRunsPolicy();

    /**
     * Static utility class; not meant to be instantiated.
     */
    private ThreadPoolUtils() {
    }

    /**
     * Builds and initializes a plain thread pool.
     *
     * <p>Any {@code null} argument (or a blank {@code threadNamePrefix}) falls back to the class defaults.
     *
     * @param corePoolSize             core pool size
     * @param maxPoolSize              maximum pool size
     * @param queueCapacity            queue capacity
     * @param keepAliveSeconds         keep-alive time for idle threads, in seconds
     * @param threadNamePrefix         thread-name prefix
     * @param rejectedExecutionHandler rejection policy
     * @return the initialized executor
     */
    public static ThreadPoolTaskExecutor buildThreadPool(Integer corePoolSize,
                                                         Integer maxPoolSize,
                                                         Integer queueCapacity,
                                                         Integer keepAliveSeconds,
                                                         String threadNamePrefix,
                                                         RejectedExecutionHandler rejectedExecutionHandler) {
        ThreadPoolTaskExecutor executor = setThreadPoolParam(corePoolSize, maxPoolSize, queueCapacity, keepAliveSeconds, threadNamePrefix, rejectedExecutionHandler);
        initialize(executor);
        return executor;
    }

    /**
     * Builds and initializes a thread pool whose tasks inherit the submitting thread's MDC context.
     *
     * <p>Any {@code null} argument (or a blank {@code threadNamePrefix}) falls back to the class defaults.
     *
     * @param corePoolSize             core pool size
     * @param maxPoolSize              maximum pool size
     * @param queueCapacity            queue capacity
     * @param keepAliveSeconds         keep-alive time for idle threads, in seconds
     * @param threadNamePrefix         thread-name prefix
     * @param rejectedExecutionHandler rejection policy
     * @return the initialized executor
     */
    public static ThreadPoolTaskExecutor buildThreadPoolWithMdcTaskDecorator(Integer corePoolSize,
                                                                             Integer maxPoolSize,
                                                                             Integer queueCapacity,
                                                                             Integer keepAliveSeconds,
                                                                             String threadNamePrefix,
                                                                             RejectedExecutionHandler rejectedExecutionHandler) {
        ThreadPoolTaskExecutor executor = setThreadPoolParam(corePoolSize, maxPoolSize, queueCapacity, keepAliveSeconds, threadNamePrefix, rejectedExecutionHandler);
        // The decorator must be set before initialize(), which builds the underlying executor.
        setTaskDecorator(executor);
        initialize(executor);
        return executor;
    }

    /**
     * Creates an executor and applies the given parameters, substituting defaults for missing ones.
     *
     * <p>When only one of {@code corePoolSize} / {@code maxPoolSize} is supplied, the other default is
     * derived from it so that {@code core <= max}. Otherwise {@code initialize()} would reject the pool.
     */
    private static ThreadPoolTaskExecutor setThreadPoolParam(Integer corePoolSize,
                                                             Integer maxPoolSize,
                                                             Integer queueCapacity,
                                                             Integer keepAliveSeconds,
                                                             String threadNamePrefix,
                                                             RejectedExecutionHandler rejectedExecutionHandler) {
        int resolvedCore;
        if (corePoolSize != null) {
            resolvedCore = corePoolSize;
        } else if (maxPoolSize != null) {
            // Do not let the default core size exceed an explicitly smaller max size.
            resolvedCore = Math.min(CORE_POOL_SIZE, maxPoolSize);
        } else {
            resolvedCore = CORE_POOL_SIZE;
        }
        // Do not let the default max size fall below an explicitly larger core size.
        int resolvedMax = maxPoolSize != null ? maxPoolSize : Math.max(MAX_POOL_SIZE, resolvedCore);

        ThreadPoolTaskExecutor executor = new ThreadPoolTaskExecutor();
        executor.setCorePoolSize(resolvedCore);
        executor.setMaxPoolSize(resolvedMax);
        executor.setQueueCapacity(queueCapacity == null ? DEFAULT_QUEUE_CAPACITY : queueCapacity);
        executor.setKeepAliveSeconds(keepAliveSeconds == null ? DEFAULT_KEEP_ALIVE_SECONDS : keepAliveSeconds);
        executor.setThreadNamePrefix(StringUtils.isBlank(threadNamePrefix) ? DEFAULT_THREAD_NAME_PREFIX : threadNamePrefix);
        executor.setRejectedExecutionHandler(rejectedExecutionHandler == null ? DEFAULT_REJECTED_EXECUTION_HANDLER : rejectedExecutionHandler);
        return executor;
    }

    /**
     * Initializes the executor, creating its underlying thread pool.
     */
    private static void initialize(ThreadPoolTaskExecutor executor) {
        executor.initialize();
    }

    /**
     * Installs a task decorator that propagates the submitting thread's MDC context to the thread that
     * runs the task.
     *
     * <p>The executing thread's own MDC is restored afterwards rather than simply cleared. This matters
     * with {@code CallerRunsPolicy} (the default here): a rejected task then runs on the submitting
     * thread, and an unconditional {@code MDC.clear()} would wipe that caller's logging context.
     *
     * @param executor the executor to decorate
     */
    private static void setTaskDecorator(ThreadPoolTaskExecutor executor) {
        executor.setTaskDecorator(runnable -> {
            // Captured on the submitting thread when the task is submitted.
            Map<String, String> submitterContext = MDC.getCopyOfContextMap();
            return () -> {
                // Captured on whichever thread actually runs the task (pool worker or the caller itself).
                Map<String, String> previousContext = MDC.getCopyOfContextMap();
                applyMdcContext(submitterContext);
                try {
                    runnable.run();
                } finally {
                    applyMdcContext(previousContext);
                }
            };
        });
    }

    /**
     * Replaces the current thread's MDC with {@code context}, or clears it when {@code context} is null.
     */
    private static void applyMdcContext(Map<String, String> context) {
        if (context == null) {
            MDC.clear();
        } else {
            MDC.setContextMap(context);
        }
    }
}
